# Data

from torchvision import datasets, transforms



class Dataset():
    """Prepare CIFAR-10 or CIFAR-100 train/test splits resized to 224 pixels.

    Args:
        data: Name of the dataset to load, ``'CIFAR10'`` (default) or
            ``'CIFAR100'``. ``data_reader`` compares it case-insensitively.
    """

    # (mean, std) handed to transforms.Normalize for each supported dataset.
    # NOTE(review): the CIFAR10 entry looks like the ImageNet statistics
    # (possibly intentional for ImageNet-pretrained backbones at 224px), and
    # the CIFAR100 entry looks like the commonly quoted CIFAR-10 statistics.
    # The values are kept as they were — confirm they are what you intend.
    _NORMALIZATION = {
        'CIFAR10': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        'CIFAR100': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    }

    def __init__(self, data='CIFAR10'):
        self.data = data

    @staticmethod
    def _build_transforms(mean, std):
        """Return ``(transform_train, transform_test)`` normalizing with mean/std."""
        transform_train = transforms.Compose([
            # First crop a random region of the image, then resize it to 224.
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        return transform_train, transform_test

    def data_reader(self, root='./data', download=True):
        """Build the train and test splits of the dataset named by ``self.data``.

        Args:
            root: Directory passed as ``root`` to the torchvision dataset.
            download: Passed as ``download`` to the torchvision dataset.

        Returns:
            ``(trainset, testset)``. The train split uses random-resized-crop
            and horizontal-flip augmentation; the test split is only resized.
            Both are converted to tensors and normalized.

        Raises:
            ValueError: if ``self.data`` is neither CIFAR10 nor CIFAR100.
        """
        # Read self.data rather than overwriting it, so the constructor
        # argument actually selects the dataset.
        name = str(self.data).upper()
        if name not in self._NORMALIZATION:
            raise ValueError(
                f"Unsupported dataset {self.data!r}; expected one of "
                f"{sorted(self._NORMALIZATION)}")

        print('==> Preparing data..')
        mean, std = self._NORMALIZATION[name]
        transform_train, transform_test = self._build_transforms(mean, std)

        dataset_cls = datasets.CIFAR10 if name == 'CIFAR10' else datasets.CIFAR100
        trainset = dataset_cls(
            root=root, train=True, download=download, transform=transform_train)
        testset = dataset_cls(
            root=root, train=False, download=download, transform=transform_test)
        return trainset, testset

